from copy import deepcopy

import numpy as np
import torch
from gymnasium.spaces import Box, Discrete

import torch.nn.functional as F
from continual_rl.policies.impala.nets import ImpalaNet
from continual_rl.utils.common_nets import ConvNet7x7, ConvNet84x84
from continual_rl.policies.arc.arc_decoder import ActionDecoder


class ActionEncoder(torch.nn.Module):
    """Predict an action embedding from a pair of consecutive observations.

    The model looks at the difference ``next_observation - observation`` (with stacked frames
    folded into the channel axis), runs it through a conv trunk, squashes the features with a
    sigmoid and projects them linearly to ``embedding_size``.

    It also supports an EWC (Elastic Weight Consolidation) penalty. Call
    ``compute_fisher_information`` to snapshot the current parameters and their diagonal Fisher
    estimate, then add ``ewc_loss()`` to later training losses.
    """

    def __init__(self, observation_space_shape, embedding_size):
        """
        Args:
            observation_space_shape: 4-element shape indexed as (frames, channels, height, width).
                Frames and channels are multiplied to form the conv input channel count.
            embedding_size: Size of the predicted action embedding.
        """
        super().__init__()
        self.embedding_size = embedding_size
        combined_observation_size = [observation_space_shape[0] * observation_space_shape[1],
                                     observation_space_shape[2],
                                     observation_space_shape[3]]
        if observation_space_shape[-1] == 7:
            self.conv_2d = ConvNet7x7(combined_observation_size)
        else:
            # Any other image size uses the 84x84 conv trunk.
            self.conv_2d = ConvNet84x84(combined_observation_size)
        # NOTE: the attribute name "liner" is kept as-is because it appears in state_dict keys.
        self.liner = torch.nn.Linear(self.conv_2d.output_shape[0], self.embedding_size)
        self.fisher_information = None  # name -> diagonal Fisher estimate, set by compute_fisher_information
        self.optimal_params = {}  # name -> detached parameter snapshot anchoring the EWC penalty

    @staticmethod
    def _merge_frames(obs):
        """Fold the frame-stack axis into the channel axis: (B, F, C, H, W) -> (B, F*C, H, W)."""
        return obs.reshape(obs.shape[0], obs.shape[1] * obs.shape[2], obs.shape[3], obs.shape[4])

    def forward(self, observation, next_observation):
        """Return the predicted action embedding for the transition observation -> next_observation.

        Both inputs are 5-D tensors indexed as (batch, frames, channels, height, width). The result
        is the output of the final linear layer, whose last dimension is ``embedding_size``.
        """
        observation = self._merge_frames(observation)
        next_observation = self._merge_frames(next_observation)

        features = self.conv_2d(next_observation - observation)
        # The sigmoid bounds the conv features to (0, 1). The linear projection below is not bounded.
        features = torch.sigmoid(features)
        return self.liner(features)

    def compute_fisher_information(self, data_set, device='cuda', batch_size=512):
        """Estimate the diagonal Fisher information of the trainable parameters on ``data_set``.

        For every batch, the squared gradients of the MSE between the predicted and target
        embeddings are accumulated, then averaged over the number of batches. The result is stored
        in ``self.fisher_information``. A detached snapshot of the current parameters is stored in
        ``self.optimal_params`` as the EWC anchor.

        Args:
            data_set: Dataset yielding (state, next_state, target_embedding) triples.
            device: Device the batches are moved to. It should match the model's device.
            batch_size: Batch size used to iterate ``data_set``.

        Raises:
            ValueError: If ``data_set`` yields no batches.
        """
        was_training = self.training
        self.eval()

        params = {n: p for n, p in self.named_parameters() if p.requires_grad}
        fisher_information = {n: torch.zeros_like(p) for n, p in params.items()}
        num_batches = 0

        data_loader = torch.utils.data.DataLoader(data_set, batch_size=batch_size, shuffle=True)
        try:
            for state, next_state, embeddings in data_loader:
                state, next_state, embeddings = state.to(device), next_state.to(device), embeddings.to(device)
                self.zero_grad()
                # forward returns a single tensor, not a tuple.
                outputs = self(state, next_state)
                loss = F.mse_loss(outputs, embeddings)
                loss.backward()

                for n, p in params.items():
                    # A parameter that did not take part in the loss contributes zero information.
                    if p.grad is not None:
                        fisher_information[n] += p.grad.detach().pow(2)
                num_batches += 1
        finally:
            # Do not leak Fisher-pass gradients into the next optimizer step, and restore the mode.
            self.zero_grad()
            self.train(was_training)

        if num_batches == 0:
            raise ValueError("compute_fisher_information received an empty data set")

        for n in fisher_information:
            fisher_information[n] /= num_batches

        self.fisher_information = fisher_information
        # detach() is essential. A plain clone() stays attached to the autograd graph of p, so the
        # gradient of (p - anchor)^2 through the anchor would cancel the gradient through p.
        self.optimal_params = {n: p.detach().clone() for n, p in params.items()}

    def ewc_loss(self):
        """Return the EWC penalty sum(F * (theta - theta_anchor)^2) over the recorded parameters.

        Returns the float ``0.0`` when no Fisher information or anchor has been computed yet.
        """
        if self.fisher_information is None or not self.optimal_params:
            return 0.0
        loss = 0.0
        for n, p in self.named_parameters():
            if n in self.fisher_information:
                fisher = self.fisher_information[n]
                optimal_param = self.optimal_params[n]
                loss += (fisher * (p - optimal_param).pow(2)).sum()
        return loss


class ARCImpalaNet(ImpalaNet):
    """IMPALA network whose policy head emits an action embedding instead of action logits.

    The embedding is mapped to logits over the real action space by a frozen copy of an
    ``ActionDecoder``, installed with ``set_action_decoder``. The actual action is then sampled
    (training) or taken greedily (eval) from those logits.
    """

    def __init__(self, observation_spaces, action_spaces, model_flags, conv_net=None):
        # The parent is built with a Discrete action space whose size equals the embedding size,
        # so its action-dependent sizes (presumably self.num_actions) equal the embedding dimension.
        # A continuous Box(low=0, high=1, shape=(embedding_size,)) space was considered instead.
        embedding_action_space = Discrete(model_flags.embedding_size)
        super().__init__(observation_spaces, {0: embedding_action_space}, model_flags, conv_net)

        # Decoder mapping action embeddings to actual actions. Decoding could be left entirely to
        # ARCMonoBeast, but a local copy is convenient for multi-process actor sampling.
        self.action_decoder = None
        # Real action spaces, indexed by action_space_id, used when decoding embeddings.
        self._actual_action_spaces = action_spaces
        # Core input = conv features + clipped reward (1) + last action embedding (see forward).
        core_output_size = self._conv_net.output_size + model_flags.embedding_size + 1

        # Replace the policy head so it outputs an action embedding bounded to (0, 1) by a sigmoid.
        self.policy = torch.nn.Sequential(torch.nn.Linear(core_output_size, model_flags.embedding_size),
                                          torch.nn.Sigmoid())

    def set_action_decoder(self, action_decoder: ActionDecoder):
        """Install an inference-only, frozen copy of ``action_decoder`` on this network's device.

        Intended to be called after the optimizer is created. The decoder's parameters are also
        marked ``requires_grad=False`` so they are never updated through this network.
        """
        device = next(self.parameters()).device
        self.action_decoder = action_decoder.copy_for_inference().to(device)
        for param in self.action_decoder.parameters():
            param.requires_grad = False

    def update_action_decoder(self, action_decoder: ActionDecoder):
        """Copy the weights of ``action_decoder`` into the local decoder copy.

        If no decoder has been installed yet, one is installed via ``set_action_decoder`` instead
        of failing on ``None``.
        """
        if self.action_decoder is None:
            self.set_action_decoder(action_decoder)
            return
        self.action_decoder.load_state_dict(action_decoder.state_dict())

    @staticmethod
    def _sample_actions(policy_logits):
        """Sample one action index per row from softmax(policy_logits), shape (rows, 1).

        Rows whose probabilities are not finite (e.g. NaN logits, which the original code notes can
        appear for buffered exploration data) are logged and fall back to a uniform distribution,
        instead of making torch.multinomial raise.
        """
        with torch.no_grad():
            probs = F.softmax(policy_logits, dim=1)
            invalid_rows = ~torch.isfinite(probs).all(dim=1)
            if invalid_rows.any():
                print("Error policy_logits:", policy_logits[invalid_rows])
                probs[invalid_rows] = 1.0 / probs.shape[1]
            return torch.multinomial(probs, num_samples=1)

    def forward(self, inputs, action_space_id, core_state=()):
        """Run the network on a (T, B) batch of inputs.

        Differences from ``ImpalaNet.forward``:
          1. The last action is fed in as its embedding rather than a one-hot vector.
          2. The policy output (an embedding) is decoded by the action decoder into logits over the
             real action space, from which the actual action is chosen.
          3. In the returned dict, "action" is the action embedding and "actual_action" is the
             decoded action index.

        There is no recurrent core. The incoming ``core_state`` is ignored and an empty tuple is
        returned.

        Raises:
            RuntimeError: If called before ``set_action_decoder``.
        """
        if self.action_decoder is None:
            raise RuntimeError("ARCImpalaNet.forward called before set_action_decoder()")

        x = inputs["frame"]
        T, B, *_ = x.shape
        x = torch.flatten(x, 0, 1)  # Merge time and batch.
        x = torch.flatten(x, 1, 2)  # Merge stacked frames and channels.
        x = F.relu(self._conv_net(x.float()))

        last_action_embedding = inputs["last_action"].view(T * B, -1).float()
        clipped_reward = torch.clamp(inputs["reward"], -1, 1).view(T * B, 1).float()
        # Core input: current state features, clipped reward and the previous action's embedding.
        core_output = torch.cat([x, clipped_reward, last_action_embedding], dim=-1)
        core_state = tuple()

        baseline = self.baseline(core_output)  # Value head.
        embedding = self.policy(core_output)  # Policy head: action embedding.

        current_action_space = self._actual_action_spaces[action_space_id]
        policy_logits, _ = self.action_decoder(embedding, action_space=current_action_space)

        if self.training:
            action = self._sample_actions(policy_logits)
        else:
            action = torch.argmax(policy_logits, dim=1)
        action = action.view(T, B)

        baseline = baseline.view(T, B, self._baseline_output_dim)
        embedding = embedding.view(T, B, self.num_actions)

        output_dict = dict(baseline=baseline[:, :, 0], action=embedding, actual_action=action)

        if self._model_flags.baseline_includes_uncertainty:
            output_dict["uncertainty"] = baseline[:, :, 1]
        return output_dict, core_state

    #
    #     # 使用父类的前向传播，并获取得到的动作嵌入
    #     output_dict, core_state = super().forward(inputs, action_space_id, core_state)
    #     embedding = output_dict["action"]
    #     embedding = torch.flatten(embedding, 0, 1)  # 将动作嵌入展平
    #     action_logit = self.action_decoder(embedding)
    #
    #     if self.training:
    #         # 训练时使用多项式分布采样动作
    #         action = torch.multinomial(F.softmax(action_logit, dim=1), num_samples=1)
    #     else:
    #         action = torch.argmax(action_logit, dim=1)
    #     action = action.view(T, B)
    #
    #     output_dict["action"] = action  # 将结果中的动作嵌入替换为解码后的动作
    #     return output_dict, core_state
